{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA; it is set before\n",
    "# TensorFlow is imported in the next cell so this notebook runs on CPU only.\n",
    "os.environ.update({'CUDA_VISIBLE_DEVICES': ''})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from t5.data import preprocessors as prep\n",
    "import functools\n",
    "import t5\n",
    "import gin\n",
    "import sentencepiece as spm\n",
    "from glob import glob\n",
    "import os\n",
    "\n",
    "# Parse the gin operative config; the relative path is resolved against the working directory.\n",
    "gin.parse_config_file('pretrained_models_base_operative_config.gin')\n",
    "# SentencePiece model file; also passed as `sentencepiece_model_path` when the task is registered.\n",
    "vocab = 'sp10m.cased.ms-en.model'\n",
    "sp = spm.SentencePieceProcessor()\n",
    "# Last expression of the cell, so the return value of Load is shown as the cell output.\n",
    "sp.Load(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single place to change the corpus location instead of repeating the absolute path.\n",
    "DATA_DIR = '/home/husein/pure-text'\n",
    "\n",
    "files = [\n",
    "    os.path.join(DATA_DIR, name)\n",
    "    for name in (\n",
    "        'dumping-parliament.txt',\n",
    "        'filtered-dumping-wiki.txt',\n",
    "        'filtered-dumping-academia.txt',\n",
    "        'dumping-news.txt',\n",
    "    )\n",
    "]\n",
    "# glob() returns matches in arbitrary, filesystem-dependent order; sort them so the\n",
    "# shards are processed in the same order on every run.\n",
    "files.extend(sorted(glob(os.path.join(DATA_DIR, 'the-pile', '*.txt'))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'dumping-parliament.txt'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at the base name of the first corpus file.\n",
    "os.path.basename(files[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"Collapse `string` onto a single line separated by single spaces.\n",
    "\n",
    "    Newlines and tabs become spaces, runs of spaces shrink to one space and\n",
    "    leading/trailing whitespace is stripped.\n",
    "    \"\"\"\n",
    "    for whitespace_char in ('\\n', '\\t'):\n",
    "        string = string.replace(whitespace_char, ' ')\n",
    "    return re.sub(r'[ ]+', ' ', string).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/husein/pure-text/dumping-parliament.txt 277157\n",
      "/home/husein/pure-text/filtered-dumping-wiki.txt 2037602\n",
      "/home/husein/pure-text/filtered-dumping-academia.txt 4086649\n",
      "/home/husein/pure-text/dumping-news.txt 3781905\n",
      "/home/husein/pure-text/the-pile/00.jsonl-16.translated.txt 1885251\n",
      "/home/husein/pure-text/the-pile/00.jsonl-17.translated.txt 1901017\n",
      "/home/husein/pure-text/the-pile/00.jsonl-18.translated.txt 1846019\n",
      "/home/husein/pure-text/the-pile/00.jsonl-0.translated.txt 1111918\n",
      "/home/husein/pure-text/the-pile/00.jsonl-31.translated.txt 1903699\n",
      "/home/husein/pure-text/the-pile/00.jsonl-11.translated.txt 1172171\n",
      "/home/husein/pure-text/the-pile/00.jsonl-8.translated.txt 1642571\n",
      "/home/husein/pure-text/the-pile/00.jsonl-28.translated.txt 1830271\n",
      "/home/husein/pure-text/the-pile/00.jsonl-26.translated.txt 1810003\n",
      "/home/husein/pure-text/the-pile/00.jsonl-6.translated.txt 1555877\n",
      "/home/husein/pure-text/the-pile/00.jsonl-25.translated.txt 1901141\n",
      "/home/husein/pure-text/the-pile/00.jsonl-32.translated.txt 1868170\n",
      "/home/husein/pure-text/the-pile/00.jsonl-13.translated.txt 1082364\n",
      "/home/husein/pure-text/the-pile/00.jsonl-19.translated.txt 1907461\n",
      "/home/husein/pure-text/the-pile/00.jsonl-5.translated.txt 1644120\n",
      "/home/husein/pure-text/the-pile/00.jsonl-9.translated.txt 1696396\n",
      "/home/husein/pure-text/the-pile/00.jsonl-4.translated.txt 1127495\n",
      "/home/husein/pure-text/the-pile/00.jsonl-24.translated.txt 1863091\n",
      "/home/husein/pure-text/the-pile/00.jsonl-10.translated.txt 1193078\n",
      "/home/husein/pure-text/the-pile/00.jsonl-23.translated.txt 1910715\n",
      "/home/husein/pure-text/the-pile/00.jsonl-29.translated.txt 1881824\n",
      "/home/husein/pure-text/the-pile/00.jsonl-15.translated.txt 1873427\n",
      "/home/husein/pure-text/the-pile/00.jsonl-7.translated.txt 1641966\n",
      "/home/husein/pure-text/the-pile/00.jsonl-30.translated.txt 1877217\n",
      "/home/husein/pure-text/the-pile/00.jsonl-1.translated.txt 1056450\n",
      "/home/husein/pure-text/the-pile/00.jsonl-2.translated.txt 1958620\n",
      "/home/husein/pure-text/the-pile/00.jsonl-14.translated.txt 1900795\n",
      "/home/husein/pure-text/the-pile/00.jsonl-21.translated.txt 1909313\n",
      "/home/husein/pure-text/the-pile/00.jsonl-3.translated.txt 1274036\n",
      "/home/husein/pure-text/the-pile/00.jsonl-27.translated.txt 1868542\n",
      "/home/husein/pure-text/the-pile/00.jsonl-20.translated.txt 1968145\n",
      "/home/husein/pure-text/the-pile/00.jsonl-22.translated.txt 1808546\n",
      "/home/husein/pure-text/the-pile/00.jsonl-12.translated.txt 1122120\n"
     ]
    }
   ],
   "source": [
    "for file in files:\n",
    "    filename = f'{os.path.split(file)[1]}.tsv'\n",
    "    line_count = 0\n",
    "    # Stream the corpus one line at a time: several inputs hold millions of lines, so\n",
    "    # read().split() would keep the whole file plus a list of every line in memory.\n",
    "    # The encoding is pinned so the result does not depend on the machine's locale\n",
    "    # (assumes the dumps are UTF-8 -- TODO confirm).\n",
    "    with open(file, encoding = 'utf-8') as fopen, tf.io.gfile.GFile(filename, 'w') as outfile:\n",
    "        for line in fopen:\n",
    "            line = line.rstrip('\\n')\n",
    "            if not line:\n",
    "                continue\n",
    "            line_count += 1\n",
    "            c = cleaning(line)\n",
    "            # Keep only lines longer than 20 characters after cleaning; the first\n",
    "            # (title) column is written empty.\n",
    "            if len(c) > 20:\n",
    "                outfile.write(f'\\t{c}\\n')\n",
    "    # Number of non-empty input lines, counted before the length filter.\n",
    "    print(file, line_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dumping_dataset(split, shuffle_files = False):\n",
    "    \"\"\"Build a dataset of {'title', 'text'} examples from two of the generated TSV files.\n",
    "\n",
    "    `split` is not used and `shuffle_files` is discarded: the same two files are\n",
    "    read whatever the arguments are.\n",
    "    \"\"\"\n",
    "    del shuffle_files\n",
    "\n",
    "    tsv_paths = ['dumping-parliament.txt.tsv', 'filtered-dumping-wiki.txt.tsv']\n",
    "    # Every line is split on tabs into two string columns, with quote handling disabled.\n",
    "    parse_line = functools.partial(\n",
    "        tf.io.decode_csv,\n",
    "        record_defaults = ['', ''],\n",
    "        field_delim = '\\t',\n",
    "        use_quote_delim = False,\n",
    "    )\n",
    "\n",
    "    def to_example(*columns):\n",
    "        return dict(zip(['title', 'text'], columns))\n",
    "\n",
    "    lines = tf.data.TextLineDataset(tsv_paths)\n",
    "    parsed = lines.map(parse_line, num_parallel_calls = tf.data.experimental.AUTOTUNE)\n",
    "    return parsed.map(to_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/t5/models/mesh_transformer.py:210: UserWarning: get_sentencepiece_model_path is deprecated. Please pass the mixture or task vocabulary directly to the Mesh TensorFlow Transformer instead.\n",
      "  \"get_sentencepiece_model_path is deprecated. Please pass the mixture or \"\n"
     ]
    }
   ],
   "source": [
    "# Remove any earlier registration before adding the task again (presumably so the\n",
    "# cell can be re-run in the same kernel).\n",
    "t5.data.TaskRegistry.remove('dumping_dataset')\n",
    "t5.data.TaskRegistry.add(\n",
    "    'dumping_dataset',\n",
    "    dataset_fn = dumping_dataset,\n",
    "    splits = ['train'],\n",
    "    # 'targets' is taken from the 'text' column; 'inputs' is mapped from None.\n",
    "    text_preprocessor = functools.partial(\n",
    "        prep.rekey, key_map = {'inputs': None, 'targets': 'text'}\n",
    "    ),\n",
    "    token_preprocessor = prep.unsupervised,\n",
    "    sentencepiece_model_path = vocab,\n",
    "    metric_fns = [],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py:1475: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    }
   ],
   "source": [
    "# The task is registered with splits = ['train'], so request that split rather than\n",
    "# the unregistered 'qa.tsv'. dumping_dataset ignores the value, but t5 releases that\n",
    "# validate the split name would reject an unknown one.\n",
    "nq_task = t5.data.TaskRegistry.get(\"dumping_dataset\")\n",
    "ds = nq_task.get_dataset(split='train', sequence_length={\"inputs\": 1024, \"targets\": 1024})\n",
    "r = tfds.as_numpy(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'inputs': array([   46, 32099,   331,   274,    29, 32098,  1123,   252,   167,\n",
       "           46,    17,  2627,    27,  9211,    29,   470,    47,    76,\n",
       "          112,   123,    16,    29,    51,    47,    76,   711,   123,\n",
       "           46, 11059,   331,   274,    29,  1014,   250,    31,    46,\n",
       "          167,    15,     7, 32097,   165,   611,  7476,     7,  5559,\n",
       "          639,    47,   124,  1008,    27,   635,    16,    53,   287,\n",
       "           17,    85,   567,    21,   178,  2155,   736,    28,   165,\n",
       "          109,   119,     7,   704,   171,    17,   779,  2336,  1930,\n",
       "          689,    17,    29, 32096,  4756,   308,   636,   610,   308,\n",
       "        32095,     3,  1439,    15,    12,  2742, 32094,  2742, 11504,\n",
       "          228, 29367,   326,  1133,    15,    12,   875,   288,  1299,\n",
       "          511,   875,   477,   228, 12114,   326,   571,   402,  4523,\n",
       "          249,  6214,    17,  5228,  2516,    32,   184,  2971,  8315,\n",
       "           17,    64,    18,    40,    16,  2374,  6214,    65,  4346,\n",
       "           22,   950,  1991,    17,   757,    64,   769,    15,     3,\n",
       "          223,   124, 32093,    17,   453,    18,  4547, 10933,   307,\n",
       "        32092, 11897,  4726,  2971,  8315,    15,     4,   491,  1760,\n",
       "         1382,    15,     5,     3,   313,  1760,  1382,    17,  2276,\n",
       "           35,  1239,  2923,   200,    59,   119,    17, 32091,  3797,\n",
       "          286, 11709,    25,  2447,    15,     4, 32090,     5,    16,\n",
       "          217,   774,  5792,    15,     4, 28098,    15,     5,     3,\n",
       "         9200,     7, 16628,    17,   852,   468,   313,  1760,  1382,\n",
       "           35,   560,  7897,   153,    16,   172,    15, 28098,    73,\n",
       "         7897,   172, 21344,  7851,    40,    15,     4,  3542,    15,\n",
       "            5,    41,  1281, 32089,   759,  9810,    43,  1239,   646,\n",
       "        21592,    15, 25677,    15,     3,   395,  1145, 13144, 32088,\n",
       "         1008,    27,  1749, 14416,  4547, 10933,   145,   355,    27,\n",
       "           29,    91,  5524,   145,   240,  1991,    17,    29,  2413,\n",
       "           43,   560,  2923,   200,    19,  1239,   646,  2837,    15,\n",
       "            3,   428,   700,    27,  3190,   725,  1465,     3,   310,\n",
       "          125,  6921,    17,  1604,    91,  1163,   318,   200, 32087,\n",
       "         4547, 10933,   150,  2608,   146,    19,  1838,  1781,  3418,\n",
       "         8083,   182,    15,     3,  1071,    50,  4052,   835,   308,\n",
       "         2427, 32086,  2479,  3731,    15,     3,  2840,  1842,  3221,\n",
       "           15,     3,  2526,  6646,  2479, 22356,   610,   308,    80,\n",
       "          702,  1041, 32085, 30574,  2187,  1246,  3321,  3955,   228,\n",
       "        12940, 10120,  7440,   326,   840,  1412, 30574,  2187,  1246,\n",
       "         3321,  3955,   228, 12940, 10120,  7440,   326,   571,  1702,\n",
       "          402,   249,  1296,   997,  3751, 32084, 12379, 29221,  9086,\n",
       "          285,  7539, 32083, 16506, 32082,    15, 26472,    19,   272,\n",
       "           16,  4203,  2378,   425,   244,   990,    16,   190,   990,\n",
       "         2360,    17, 10539,    19, 32081,    40, 32080,  1500,    80,\n",
       "          768,     3,   227, 15194,   686,  1439,    15,    12,  3862,\n",
       "         3462, 19702,   244,    15,     3,  2576,  3565,  1702,   402,\n",
       "          223,   124,    18, 32079, 18779,  5058, 23113,    15,     4,\n",
       "        12379, 29221,  9086, 32078, 16506,    27,    15, 26472,    15,\n",
       "            5,    40,    33,  2276,    15,     3,  4759,    65,  7641,\n",
       "         2363,    16, 32077,  4704,    32,   129,    15,     3,  7339,\n",
       "           13,    80,  7834,   160,    17,   851, 11197, 12218,  1064,\n",
       "        32076,  1959,     7,  1861, 32075,  2352,    36,   788,    41,\n",
       "          652,  2402,    46,  2995,    22,  1735,    56,  2435,    37,\n",
       "          179,  6464,    15,     3,  1151,   787,    27,   103,   863,\n",
       "        11232, 32074,  4691,    43,  3753,  2741,  2999,   179,  6464,\n",
       "           15, 32073,     1]),\n",
       " 'targets': array([32099, 11059, 32098,  1014,   241,    31,    46,   167,     7,\n",
       "         2938,    16,    43, 32097,  1576, 32096,     3, 32095,    80,\n",
       "          702,  1041,   768,     3,   227,    15, 32094,  3462, 13757,\n",
       "         1040, 32093,    18,     7,   738,    27, 11897, 32092,    41,\n",
       "        32091,  1090,    32,  2886, 32090, 25677,    15, 32089,    21,\n",
       "          560, 32088,   124, 32087,    18, 32086,  1500,   702, 32085,\n",
       "          840,  1412, 32084,   501,    15,     6, 32083,  2141, 32082,\n",
       "           15,     6,    46, 32081,  2562, 32080,     3, 32079,     7,\n",
       "          738,    27,   395,  1145, 13144,   124,  1008,    27,  5553,\n",
       "          898,  4759, 32078,   285,  7539,  2141, 32077, 29381, 11197,\n",
       "        12218,  1064,   179,  6464,    17, 32076,    19, 32075,    55,\n",
       "         8591,    47, 15821, 32074, 15821,   352, 32073,     3,     1])}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# iter() is a no-op on a generator and also covers tensorflow_datasets versions in\n",
    "# which as_numpy returns an iterable rather than an iterator.\n",
    "next(iter(r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
